{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distilling DeepSeek R1\n"
   ]
  },
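  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reusing eval.py for inference\n",
    "\n",
    "Distillation works the same way as evaluation: run batch inference to collect the chains of thought, then use rejection sampling to keep only the samples whose answers are correct.\n",
    "\n",
    "To keep costs down, each training example is sampled only once in this run.\n",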
    "\n",
    "```bash\n",
    "CURATOR_VIEWER=1 python eval.py --provider qianfan --data_path data/train.jsonl --model_name deepseek-r1 --temperature 0.6 --max_tokens 8192 --is_reasoning True --output_prefix distill --max_concurrent_requests 20\n",
    "```"
   ]
  },
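  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The filter behind rejection sampling can be sketched as follows. This is a minimal sketch, not the check used in `eval.py`: `is_correct` and `_evaluate` are hypothetical names, and the rules (every number used exactly once, exact arithmetic, an optional `= target` tail ignored) are assumptions based on the task prompt.\n",
    "\n",
    "```python\n",
    "import ast\n",
    "import operator\n",
    "from collections import Counter\n",
    "from fractions import Fraction\n",
    "\n",
    "# Binary operators allowed by the task prompt.\n",
    "_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,\n",
    "        ast.Mult: operator.mul, ast.Div: operator.truediv}\n",
    "\n",
    "\n",
    "def _evaluate(node, used):\n",
    "    # Evaluate the AST exactly and record every integer literal in `used`.\n",
    "    if isinstance(node, ast.Constant) and type(node.value) is int:\n",
    "        used.append(node.value)\n",
    "        return Fraction(node.value)\n",
    "    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:\n",
    "        left = _evaluate(node.left, used)\n",
    "        right = _evaluate(node.right, used)\n",
    "        return _OPS[type(node.op)](left, right)\n",
    "    raise ValueError('unsupported expression')\n",
    "\n",
    "\n",
    "def is_correct(solution, numbers, target):\n",
    "    # Hypothetical checker: each number used once and the value equals target.\n",
    "    expression = solution.split('=')[0].strip()  # drop an optional '= 67' tail\n",
    "    used = []\n",
    "    try:\n",
    "        value = _evaluate(ast.parse(expression, mode='eval').body, used)\n",
    "    except (SyntaxError, ValueError, ZeroDivisionError):\n",
    "        return False\n",
    "    return Counter(used) == Counter(numbers) and value == target\n",
    "```\n",
    "\n",
    "With a single sample per training example, rejection sampling reduces to keeping the completions for which a check like this passes."
   ]
  },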
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the provider's batch inference feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from eval import Reasoner, EvalResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"custom_id\": \"request-0\",\n",
      "    \"body\": {\n",
      "        \"messages\": [\n",
      "            {\n",
      "                \"role\": \"user\",\n",
      "                \"content\": \"Using the numbers 29, 11, 27, 10, create an equation that equals 524. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>.\"\n",
      "            }\n",
      "        ],\n",
      "        \"max_tokens\": 8192,\n",
      "        \"temperature\": 0.6\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "llm = Reasoner(backend=\"vllm\", model_name=\"deepseek-r1\", is_reasoning=True)\n",
    "\n",
    "data = []\n",
    "input_data = {}\n",
    "with open(\"data/train.jsonl\", \"r\") as f:\n",
    "    for i, line in enumerate(f):\n",
    "        d = json.loads(line)\n",
    "\n",
    "        # Build one batch request per training example; custom_id maps each response back to its input\n",
    "        messages = llm.prompt(d)\n",
    "        data.append({\n",
    "            \"custom_id\": f\"request-{i}\",\n",
    "            \"body\": {\"messages\": messages, \"max_tokens\": 8192, \"temperature\": 0.6},\n",
    "        })\n",
    "        input_data[f\"request-{i}\"] = d\n",
    "\n",
    "print(json.dumps(data[0], indent=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/distill_batch_requests.jsonl\", \"w\") as f:\n",
    "    for d in data:\n",
    "        f.write(json.dumps(d, ensure_ascii=False) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'target': 67,\n",
       " 'numbers': [93, 14, 3, 98],\n",
       " 'prompt': 'Using the numbers 93, 14, 3, 98, create an equation that equals 67. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>.',\n",
       " 'completion': \"<think>Okay, let's see. I need to use all the numbers 93, 14, 3, and 98 to make an equation that equals 67. Hmm. Let's start by thinking about the operations. Maybe subtract some numbers first. Since 98 is a big number, if I subtract something from it, maybe that's part of the solution. Let me try 98 minus something. Let's see, 98 minus 14 is 84, but that's still bigger than 67. If I subtract 93 from 98, that's 5. Then maybe add 14 and 3? Wait, but I have to use all numbers. Let me write down the numbers: 93, 14, 3, 98.\\n\\nAnother approach: maybe combine some numbers with multiplication or division. Let's see, 14 divided by 3 is roughly 4.666, but not sure if that helps. What if I do 93 minus something? 93 minus 14 is 79, which is still higher than 67. Then maybe subtract 3 and 98? Wait, but I can't do that because all numbers must be used once. Wait, maybe I need to group them into different operations.\\n\\nWait, what if I take 98 and subtract 3, which is 95. Then subtract something else. If I subtract 93 from 95, that's 2. Then add 14? 2 +14 is 16. No, that's not 67. Hmm.\\n\\nAlternatively, maybe 98 divided by something. 98 divided by 14 is 7. Then 7 multiplied by 3 is 21. Then 93 minus 21 is 72. Not 67. Close though. Maybe different operations.\\n\\nWhat about 93 minus (14 + 3) which is 93 -17=76. Then 98 -76=22. Not helpful. Wait, but I need to use all numbers. So maybe 98 - (93 - 14) /3. Let's check: 93-14=79. 79 divided by 3 is about 26.333. Then 98-26.333=71.666, which is not 67.\\n\\nWait, maybe 93 - (98 - 14)/3. Let's compute that. 98-14=84. 84 divided by 3 is 28. 93-28=65. Close to 67 but not quite. Hmm.\\n\\nWhat if I rearrange the operations. Let's try 14 multiplied by 3 is 42. Then 98 minus 42 is 56. Then 93 minus 56 is 37. Not helpful. Or 93 minus 98 is -5. Then -5 +14=9, then 9*3=27. Nope.\\n\\nWait, maybe (98 - 93) is 5. Then 5 multiplied by 14 is 70. Then 70 -3=67. Oh! That works. Let's check: (98 -93) =5. 
Then 5 *14=70. Then 70 -3=67. But wait, all numbers must be used. Let's see: 98,93,14,3. Yes, all are used once. So the equation would be (98 -93) *14 -3 =67. Let me verify that:\\n\\n98 minus93 is 5, multiply by14 gives 70, subtract3 gives 67. Yes! That works. So the equation is (98 -93) *14 -3 =67.\\n</think>\\n\\n<answer>(98 - 93) * 14 - 3 = 67</answer>\",\n",
       " 'reasoning': \"Okay, let's see. I need to use all the numbers 93, 14, 3, and 98 to make an equation that equals 67. Hmm. Let's start by thinking about the operations. Maybe subtract some numbers first. Since 98 is a big number, if I subtract something from it, maybe that's part of the solution. Let me try 98 minus something. Let's see, 98 minus 14 is 84, but that's still bigger than 67. If I subtract 93 from 98, that's 5. Then maybe add 14 and 3? Wait, but I have to use all numbers. Let me write down the numbers: 93, 14, 3, 98.\\n\\nAnother approach: maybe combine some numbers with multiplication or division. Let's see, 14 divided by 3 is roughly 4.666, but not sure if that helps. What if I do 93 minus something? 93 minus 14 is 79, which is still higher than 67. Then maybe subtract 3 and 98? Wait, but I can't do that because all numbers must be used once. Wait, maybe I need to group them into different operations.\\n\\nWait, what if I take 98 and subtract 3, which is 95. Then subtract something else. If I subtract 93 from 95, that's 2. Then add 14? 2 +14 is 16. No, that's not 67. Hmm.\\n\\nAlternatively, maybe 98 divided by something. 98 divided by 14 is 7. Then 7 multiplied by 3 is 21. Then 93 minus 21 is 72. Not 67. Close though. Maybe different operations.\\n\\nWhat about 93 minus (14 + 3) which is 93 -17=76. Then 98 -76=22. Not helpful. Wait, but I need to use all numbers. So maybe 98 - (93 - 14) /3. Let's check: 93-14=79. 79 divided by 3 is about 26.333. Then 98-26.333=71.666, which is not 67.\\n\\nWait, maybe 93 - (98 - 14)/3. Let's compute that. 98-14=84. 84 divided by 3 is 28. 93-28=65. Close to 67 but not quite. Hmm.\\n\\nWhat if I rearrange the operations. Let's try 14 multiplied by 3 is 42. Then 98 minus 42 is 56. Then 93 minus 56 is 37. Not helpful. Or 93 minus 98 is -5. Then -5 +14=9, then 9*3=27. Nope.\\n\\nWait, maybe (98 - 93) is 5. Then 5 multiplied by 14 is 70. Then 70 -3=67. Oh! That works. Let's check: (98 -93) =5. Then 5 *14=70. 
Then 70 -3=67. But wait, all numbers must be used. Let's see: 98,93,14,3. Yes, all are used once. So the equation would be (98 -93) *14 -3 =67. Let me verify that:\\n\\n98 minus93 is 5, multiply by14 gives 70, subtract3 gives 67. Yes! That works. So the equation is (98 -93) *14 -3 =67.\\n\",\n",
       " 'solution': '(98 - 93) * 14 - 3 = 67',\n",
       " 'correct': True,\n",
       " 'correct_reason': 'Solution is correct'}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Place the provider's batch output file (results.jsonl) under output/\n",
    "\n",
    "results = []\n",
    "\n",
    "with open(\"output/results.jsonl\", \"r\") as f:\n",
    "    for line in f:\n",
    "        d = json.loads(line)\n",
    "        custom_id = d['custom_id']\n",
    "        sample = input_data[custom_id]\n",
    "        # \"response\" may be missing or null when the request failed\n",
    "        response = (d.get(\"response\") or {}).get(\"body\")\n",
    "        if not response:\n",
    "            # Record the failed request as incorrect and skip parsing\n",
    "            results.append({\n",
    "                \"correct\": False,\n",
    "                \"correct_reason\": \"call llm failed\",\n",
    "                \"target\": sample[\"target\"],\n",
    "                \"numbers\": sample[\"numbers\"],\n",
    "                \"prompt\": \"\",\n",
    "                \"completion\": \"\",\n",
    "                \"reasoning\": \"\",\n",
    "            })\n",
    "            continue\n",
    "        results.append(llm.parse(sample, response)[0])\n",
    "\n",
    "results[0]\n"
   ]
  },
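  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sample result above carries its solution as `(98 - 93) * 14 - 3 = 67`, with the target appended, while the example answer in the prompt is a bare expression. One way to make the answers uniform is to strip that tail. The helper below is only a sketch: `strip_equation_tail` is a hypothetical name, and it assumes the expression sits to the left of the first `=`.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def strip_equation_tail(completion):\n",
    "    # Rewrite '<answer>expr = value</answer>' as '<answer>expr</answer>'.\n",
    "    def _clean(match):\n",
    "        # Keep only the text before the first '=' (assumed to be the expression).\n",
    "        expression = match.group(1).split('=')[0].strip()\n",
    "        return '<answer>' + expression + '</answer>'\n",
    "\n",
    "    return re.sub(r'<answer>(.*?)</answer>', _clean, completion, flags=re.DOTALL)\n",
    "```\n",
    "\n",
    "Answers written with the target on the left, such as `67 = ...`, would need separate handling."
   ]
  },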
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 886/1000 (88.60%)\n"
     ]
    }
   ],
   "source": [
    "total_count = len(results)\n",
    "correct_count = sum([1 for r in results if r[\"correct\"]])\n",
    "accuracy = correct_count / total_count\n",
    "print(f\"Accuracy: {correct_count}/{total_count} ({accuracy:.2%})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SFT Data: 712\n"
     ]
    }
   ],
   "source": [
    "no_reasoning_llm = Reasoner(backend=\"vllm\", model_name=\"deepseek-r1\", is_reasoning=False)\n",
    "\n",
    "sft_data = []\n",
    "sft_data_reasoning = []\n",
    "\n",
    "for d in results:\n",
    "    if not d[\"correct\"]:\n",
    "        # Rejection sampling: leave incorrect samples out of the SFT data\n",
    "        continue\n",
    "    if len(d[\"completion\"]) >= 7500:\n",
    "        # Drop completions of 7500 characters or more (len() counts characters, not tokens)\n",
    "        # The 7500 cutoff is meant to keep the total length within 8192\n",
    "        continue\n",
    "    messages = no_reasoning_llm.prompt(d)\n",
    "    messages.append({\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": d[\"completion\"]\n",
    "    })\n",
    "    sft_data.append({\n",
    "        \"messages\": messages,\n",
    "    })\n",
    "    messages_reasoning = llm.prompt(d)\n",
    "    messages_reasoning.append({\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": d[\"completion\"]\n",
    "    })\n",
    "    sft_data_reasoning.append({\n",
    "        \"messages\": messages_reasoning,\n",
    "    })\n",
    "\n",
    "print(f\"SFT Data: {len(sft_data)}\")\n",
    "with open(\"data/train_sft_distill.jsonl\", \"w\") as f:\n",
    "    for d in sft_data:\n",
    "        f.write(json.dumps(d, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "with open(\"data/train_sft_distill_reasoning.jsonl\", \"w\") as f:\n",
    "    for d in sft_data_reasoning:\n",
    "        f.write(json.dumps(d, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "# TODO: either keep or strip the equality (e.g. the trailing = 67) in every answer so the format is consistent"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
